{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8e1bbb5c",
   "metadata": {},
   "source": [
    "## 基于MindNLP的Bert模型应用开发\n",
    "我们使用 Hugging Face 提供的预训练模型 `'huggingface-course/bert-finetuned-ner'` 来进行命名实体识别（NER）任务。此模型已经在特定的 NER 数据集上进行微调，能够自动识别文本中的实体类型，包括人物（PER）、地理位置（GEO）、组织（ORG）等。模型所使用的数据集是 Kaggle 的公开 NER 数据集，其中包含丰富的标签信息，帮助模型准确地理解和标记文本中的不同实体。\n",
    "\n",
    "在模型推理过程中，给定的输入句子会被分词并标注，每个单词都带有对应的实体标签。例如，给定句子 `\"Huawei is a company based in Shen Zhen, but it also has employees like Eric Zhang working in Singapore\"`，模型输出结果如下：\n",
    "\n",
    "- **原始句子分词**：`['Huawei', 'is', 'a', 'company', 'based', 'in', 'Shen', 'Zhen,', 'but', 'it', 'also', 'has', 'employees', 'like', 'Eric', 'Zhang', 'working', 'in', 'Singapore']`\n",
    "- **预测标签**：`['B-org', 'O', 'O', 'O', 'O', 'O', 'B-geo', 'I-geo', 'O', 'O', 'O', 'O', 'O', 'O', 'B-per', 'I-per', 'O', 'O', 'B-geo']`\n",
    "\n",
    "在这个示例中，模型成功地将 `\"Huawei\"` 标记为组织（`B-org`），将 `\"Shen Zhen\"` 和 `\"Singapore\"` 标记为地理位置（`B-geo` 和 `I-geo`），并将 `\"Eric Zhang\"` 标记为人物（`B-per` 和 `I-per`）。其他非实体词则被标记为 `O`，表示它们不是实体的一部分。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ec32e22",
   "metadata": {},
   "source": [
    "### **环境配置：**\n",
    "\n",
    "1. MindSpore 2.3.1\n",
    "2. Mindnlp 0.4.0\n",
    "3. Python 3.9.0\n",
    "4. kagglehub\n",
    "5. pandas\n",
    "6. numpy"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "152f3321-6c18-4cca-bce9-d8c16a113064",
   "metadata": {},
   "source": [
    "**使用华为云 ModelArts 作为AI平台**\n",
    "\n",
    "也可以在华为云 ModelArts上 直接跑本项目Notebook: https://pangu.huaweicloud.com/gallery/asset-detail.html?id=e12870f6-f88f-4011-bc8e-f99ce0a7e458"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01698a31-ab6f-448f-b23e-ed0c9ff49489",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Name: mindspore\n",
      "Version: 2.3.1\n",
      "Summary: MindSpore is a new open source deep learning training/inference framework that could be used for mobile, edge and cloud scenarios.\n",
      "Home-page: https://www.mindspore.cn\n",
      "Author: The MindSpore Authors\n",
      "Author-email: contact@mindspore.cn\n",
      "License: Apache 2.0\n",
      "Location: /home/ma-user/anaconda3/envs/mindnlp/lib/python3.9/site-packages\n",
      "Requires: asttokens, astunparse, numpy, packaging, pillow, protobuf, psutil, scipy\n",
      "Required-by: mindnlp\n",
      "Name: mindnlp\n",
      "Version: 0.4.0\n",
      "Summary: An open source natural language processing research tool box. Git version: [sha1]:5b4dad33, [branch]: (HEAD -> master, ms/master)\n",
      "Home-page: https://github.com/mindlab-ai/mindnlp/tree/master/\n",
      "Author: MindSpore Team\n",
      "Author-email: \n",
      "License: Apache 2.0\n",
      "Location: /home/ma-user/anaconda3/envs/mindnlp/lib/python3.9/site-packages\n",
      "Requires: addict, datasets, evaluate, jieba, mindspore, ml-dtypes, pillow, pyctcdecode, pytest, regex, requests, safetensors, sentencepiece, tokenizers, tqdm\n",
      "Required-by: \n"
     ]
    }
   ],
   "source": [
    "!pip show mindspore\n",
    "!pip show mindnlp"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0fcfb5f5",
   "metadata": {},
   "source": [
    "### 其他所需库安装"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3a3a53c2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Looking in indexes: http://repo.myhuaweicloud.com/repository/pypi/simple\n",
      "Collecting kagglehub\n",
      "  Downloading http://repo.myhuaweicloud.com/repository/pypi/packages/c4/d1/4ab25019a168f5c414202f124d156e11ac79f07845d67288929311f1b1b2/kagglehub-0.3.3-py3-none-any.whl (42 kB)\n",
      "Requirement already satisfied: packaging in /home/ma-user/anaconda3/envs/mindnlp/lib/python3.9/site-packages (from kagglehub) (24.1)\n",
      "Requirement already satisfied: requests in /home/ma-user/anaconda3/envs/mindnlp/lib/python3.9/site-packages (from kagglehub) (2.32.3)\n",
      "Requirement already satisfied: tqdm in /home/ma-user/anaconda3/envs/mindnlp/lib/python3.9/site-packages (from kagglehub) (4.66.5)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in /home/ma-user/anaconda3/envs/mindnlp/lib/python3.9/site-packages (from requests->kagglehub) (3.4.0)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /home/ma-user/anaconda3/envs/mindnlp/lib/python3.9/site-packages (from requests->kagglehub) (3.10)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in /home/ma-user/anaconda3/envs/mindnlp/lib/python3.9/site-packages (from requests->kagglehub) (2.2.3)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /home/ma-user/anaconda3/envs/mindnlp/lib/python3.9/site-packages (from requests->kagglehub) (2024.8.30)\n",
      "Installing collected packages: kagglehub\n",
      "Successfully installed kagglehub-0.3.3\n",
      "Note: you may need to restart the kernel to use updated packages.\n",
      "Looking in indexes: http://repo.myhuaweicloud.com/repository/pypi/simple\n",
      "Requirement already satisfied: pandas in /home/ma-user/anaconda3/envs/mindnlp/lib/python3.9/site-packages (2.2.3)\n",
      "Requirement already satisfied: numpy>=1.22.4 in /home/ma-user/anaconda3/envs/mindnlp/lib/python3.9/site-packages (from pandas) (1.26.4)\n",
      "Requirement already satisfied: python-dateutil>=2.8.2 in /home/ma-user/anaconda3/envs/mindnlp/lib/python3.9/site-packages (from pandas) (2.9.0.post0)\n",
      "Requirement already satisfied: pytz>=2020.1 in /home/ma-user/anaconda3/envs/mindnlp/lib/python3.9/site-packages (from pandas) (2024.2)\n",
      "Requirement already satisfied: tzdata>=2022.7 in /home/ma-user/anaconda3/envs/mindnlp/lib/python3.9/site-packages (from pandas) (2024.2)\n",
      "Requirement already satisfied: six>=1.5 in /home/ma-user/anaconda3/envs/mindnlp/lib/python3.9/site-packages (from python-dateutil>=2.8.2->pandas) (1.16.0)\n",
      "Note: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "%pip install kagglehub\n",
    "%pip install pandas"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ecff84e0",
   "metadata": {},
   "source": [
    "### 2 Bert 模型\n",
    "原文地址：https://arxiv.org/abs/1810.04805\n",
    "\n",
    "原TensorFlow代码：https://github.com/google-research/bert  \n",
    "原Pytorch代码：https://github.com/codertimo/BERT-pytorch "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "93381805",
   "metadata": {},
   "source": [
    "#### 2.1 Bert 模型\n",
    "\n",
    "BERT（Bidirectional Encoder Representations from Transformers）是一种由谷歌提出的预训练语言模型，旨在通过双向 Transformer 编码器来捕获语言的上下文信息。与传统的单向语言模型不同，BERT 通过双向（从左到右和从右到左）同时处理文本，从而理解每个词在整个句子中的真正含义。这种方法能够更好地理解句子中的词义和结构，提高对自然语言的理解能力。BERT 模型的预训练采用了两个主要任务：**掩码语言模型（Masked Language Model，MLM）** 和 **下一个句子预测（Next Sentence Prediction，NSP）**。在 MLM 任务中，BERT 会随机掩盖输入中的部分词，并通过模型预测这些被掩盖的词，从而学习词汇和上下文之间的关系。而 NSP 任务则训练模型判断两个句子是否紧密关联，帮助模型理解句子间的逻辑关系。\n",
    "\n"
   ]
  },
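  {
   "cell_type": "markdown",
   "id": "c3a91f07",
   "metadata": {},
   "source": [
    "To make the MLM idea concrete, here is a minimal sketch of the masking step in plain Python. It is a simplified illustration, not the pretraining code behind the model used in this notebook: the helper name `mask_tokens`, the `seed` argument, and the toy sentence are all hypothetical, and the sketch always substitutes `[MASK]`, whereas the BERT paper selects about 15% of the tokens and sometimes keeps the original token or swaps in a random one instead.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "MASK_TOKEN = '[MASK]'\n",
    "\n",
    "def mask_tokens(tokens, mask_prob=0.15, seed=0):\n",
    "    '''Replace a random subset of tokens with [MASK] and return the targets.'''\n",
    "    rng = random.Random(seed)  # seeded so the illustration is repeatable\n",
    "    masked = list(tokens)\n",
    "    targets = {}  # position -> original token the model would have to predict\n",
    "    for i, token in enumerate(tokens):\n",
    "        if rng.random() < mask_prob:\n",
    "            targets[i] = token\n",
    "            masked[i] = MASK_TOKEN\n",
    "    return masked, targets\n",
    "\n",
    "# Toy sentence; a higher mask_prob than 0.15 so that a short input gets masked\n",
    "tokens = 'the model predicts the hidden words from their context'.split()\n",
    "masked, targets = mask_tokens(tokens, mask_prob=0.3, seed=1)\n",
    "print(masked)\n",
    "print(targets)\n",
    "```\n",
    "\n",
    "During pretraining the loss is computed only at the positions stored in `targets`, which is what pushes the model to use the context on both sides of each masked token."
   ]
  },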
  {
   "cell_type": "markdown",
   "id": "6580d044",
   "metadata": {},
   "source": [
    "#### 2.2 导入模型到mindnlp\n",
    "可以从中 `dir(mindnlp.transformers)` 查看和导入到关于Bert的包"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "02b4aca1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"HF_ENDPOINT\"] = \"https://hf-mirror.com\"       #huggingface镜像      "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "48ad7d2b-3465-487d-b206-3a61bc823d62",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ma-user/anaconda3/envs/mindnlp/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "Building prefix dict from the default dictionary ...\n",
      "Dumping model to file cache /tmp/jieba.cache\n",
      "Loading model cost 0.626 seconds.\n",
      "Prefix dict has been built successfully.\n"
     ]
    }
   ],
   "source": [
    "import mindnlp\n",
    "# dir(mindnlp.transformers)   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9e1e5599",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ma-user/anaconda3/envs/mindnlp/lib/python3.9/site-packages/mindnlp/transformers/tokenization_utils_base.py:1526: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted, and will be then set to `False` by default. \n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from mindspore import Tensor\n",
    "from mindnlp.transformers import BertTokenizerFast,BertForTokenClassification\n",
    "\n",
    "# 初始化分词器和模型\n",
    "tokenizer = BertTokenizerFast.from_pretrained('huggingface-course/bert-finetuned-ner')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6156e180",
   "metadata": {},
   "source": [
    "#### 2.4 加载数据\n",
    "使用的是Kaggle 的 [`ner-dataset`](https://www.kaggle.com/datasets/namanj27/ner-dataset)，这是一个用于命名实体识别（NER）任务的公开数据集，主要用于训练和评估模型在识别文本中实体的能力。数据集包含大量已标注的文本，每个单词或标记（token）都带有相应的标签，表明它在文本中的角色，如一个示例行标签的含义如下：\n",
    "\n",
    "### **BIO 标注格式**\n",
    "- **B-** 表示实体的开始（Begin）。\n",
    "- **I-** 表示实体的内部（Inside）。\n",
    "- **O** 表示非实体（Outside）。\n",
    "\n",
    "### **标签含义及示例**\n",
    "\n",
    "| 标签       | 含义                                    | 示例句子                         | 解释                                             |\n",
    "|------------|-----------------------------------------|----------------------------------|--------------------------------------------------|\n",
    "| **O**      | Outside，非实体                         | \"The sky is blue.\"               | \"sky\" 和 \"blue\" 都是 `O`，表示非实体词。          |\n",
    "| **B-per**  | Begin-person，人物名称的开始            | \"Albert Einstein was a physicist.\" | \"Albert\" 被标记为 `B-per`，表示人物的开始。        |\n",
    "| **I-per**  | Inside-person，人物名称的内部词         | \"Albert Einstein\"                | \"Einstein\" 被标记为 `I-per`，表示人物的内部部分。   |\n",
    "| **B-org**  | Begin-organization，组织名称的开始      | \"Google is a tech company.\"      | \"Google\" 被标记为 `B-org`，表示组织的开始。       |\n",
    "| **I-org**  | Inside-organization，组织名称的内部词   | \"University of Oxford\"           | \"of\" 和 \"Oxford\" 被标记为 `I-org`，表示组织名称的内部部分。 |\n",
    "| **B-geo**  | Begin-geographical location，地理位置的开始 | \"Paris is beautiful.\"            | \"Paris\" 被标记为 `B-geo`，表示一个地理位置的开始。 |\n",
    "| **I-geo**  | Inside-geographical location，地理位置的内部词 | \"San Francisco Bay\"              | \"Francisco\" 和 \"Bay\" 被标记为 `I-geo`，表示地理位置内部部分。 |\n",
    "| **B-misc** | Begin-miscellaneous，其他实体的开始      | \"The Olympics are held every four years.\" | \"Olympics\" 被标记为 `B-misc`，表示其他实体的开始。 |\n",
    "| **I-misc** | Inside-miscellaneous，其他实体的内部词   | \"The 2020 Summer Olympics\"       | \"Summer\" 和 \"Olympics\" 被标记为 `I-misc`，表示其他实体的内部部分。 |\n",
    "\n",
    "### **具体示例解析**\n",
    "\n",
    "给定以下句子：\n",
    "\n",
    "> \"Albert Einstein was born in Ulm, Germany and worked at the University of Zurich.\"\n",
    "\n",
    "模型会给出的标签为：\n",
    "\n",
    "| 单词       | 标签     | 解释                                  |\n",
    "|------------|----------|---------------------------------------|\n",
    "| Albert     | B-per    | `B-per` 表示人物名称的开始。           |\n",
    "| Einstein   | I-per    | `I-per` 表示人物名称的内部词。         |\n",
    "| was        | O        | `O` 表示非实体。                      |\n",
    "| born       | O        | `O` 表示非实体。                      |\n",
    "| in         | O        | `O` 表示非实体。                      |\n",
    "| Ulm        | B-geo    | `B-geo` 表示地理位置的开始。          |\n",
    "| ,          | O        | `O` 表示非实体标点符号。              |\n",
    "| Germany    | B-geo    | `B-geo` 表示地理位置的开始。          |\n",
    "| and        | O        | `O` 表示非实体。                      |\n",
    "| worked     | O        | `O` 表示非实体。                      |\n",
    "| at         | O        | `O` 表示非实体。                      |\n",
    "| the        | O        | `O` 表示非实体。                      |\n",
    "| University | B-org    | `B-org` 表示组织名称的开始。           |\n",
    "| of         | I-org    | `I-org` 表示组织名称的内部词。         |\n",
    "| Zurich.    | I-org    | `I-org` 表示组织名称的内部词。         |\n"
   ]
  },
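  {
   "cell_type": "markdown",
   "id": "5d2e8b14",
   "metadata": {},
   "source": [
    "The BIO tags above can be turned back into whole entities by grouping each `B-` word with the `I-` words that follow it. The helper below is a small sketch of that grouping; the name `bio_to_entities` is hypothetical and not part of MindNLP or the dataset, and an `I-` tag that does not continue the current entity is treated as the start of a new one.\n",
    "\n",
    "```python\n",
    "def bio_to_entities(words, tags):\n",
    "    '''Group BIO-tagged words into (entity_text, entity_type) pairs.'''\n",
    "    entities = []\n",
    "    current_words, current_type = [], None\n",
    "    for word, tag in zip(words, tags):\n",
    "        prefix, _, entity_type = tag.partition('-')\n",
    "        if prefix == 'B' or (prefix == 'I' and entity_type != current_type):\n",
    "            # A new entity starts here: store the previous one first\n",
    "            if current_words:\n",
    "                entities.append((' '.join(current_words), current_type))\n",
    "            current_words, current_type = [word], entity_type\n",
    "        elif prefix == 'I':\n",
    "            # Same type as the open entity, so extend it\n",
    "            current_words.append(word)\n",
    "        else:\n",
    "            # An O tag closes any open entity\n",
    "            if current_words:\n",
    "                entities.append((' '.join(current_words), current_type))\n",
    "            current_words, current_type = [], None\n",
    "    if current_words:\n",
    "        entities.append((' '.join(current_words), current_type))\n",
    "    return entities\n",
    "\n",
    "words = ['Albert', 'Einstein', 'was', 'born', 'in', 'Ulm']\n",
    "tags = ['B-per', 'I-per', 'O', 'O', 'O', 'B-geo']\n",
    "print(bio_to_entities(words, tags))\n",
    "# prints [('Albert Einstein', 'per'), ('Ulm', 'geo')]\n",
    "```\n",
    "\n",
    "The same helper can be applied to the word and tag lists returned by the inference function later in this notebook."
   ]
  },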
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a7b5aca3-6d7d-4a39-b4d5-6ceca3b97e64",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading from https://www.kaggle.com/api/v1/datasets/download/namanj27/ner-dataset?dataset_version_number=2...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3.17M/3.17M [00:02<00:00, 1.65MB/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting files...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Path to dataset files: /home/ma-user/.cache/kagglehub/datasets/namanj27/ner-dataset/versions/2\n"
     ]
    }
   ],
   "source": [
    "import kagglehub\n",
    "\n",
    "# 指定下载路径为当前目录\n",
    "path = kagglehub.dataset_download(\"namanj27/ner-dataset\")\n",
    "print(\"Path to dataset files:\", path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "66722443-3077-4c0d-a443-71193c555f4e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Sentence #</th>\n",
       "      <th>Word</th>\n",
       "      <th>POS</th>\n",
       "      <th>Tag</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Sentence: 1</td>\n",
       "      <td>Thousands</td>\n",
       "      <td>NNS</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>NaN</td>\n",
       "      <td>of</td>\n",
       "      <td>IN</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>NaN</td>\n",
       "      <td>demonstrators</td>\n",
       "      <td>NNS</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>NaN</td>\n",
       "      <td>have</td>\n",
       "      <td>VBP</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>NaN</td>\n",
       "      <td>marched</td>\n",
       "      <td>VBN</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    Sentence #           Word  POS Tag\n",
       "0  Sentence: 1      Thousands  NNS   O\n",
       "1          NaN             of   IN   O\n",
       "2          NaN  demonstrators  NNS   O\n",
       "3          NaN           have  VBP   O\n",
       "4          NaN        marched  VBN   O"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "data = pd.read_csv(path+\"/ner_datasetreference.csv\", encoding='unicode_escape')\n",
    "data.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e6e8698c",
   "metadata": {},
   "source": [
    "为模型准备数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "c66e4202",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of tags: 17\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Tag\n",
       "O        887908\n",
       "B-geo     37644\n",
       "B-tim     20333\n",
       "B-org     20143\n",
       "I-per     17251\n",
       "B-per     16990\n",
       "I-org     16784\n",
       "B-gpe     15870\n",
       "I-geo      7414\n",
       "I-tim      6528\n",
       "B-art       402\n",
       "B-eve       308\n",
       "I-art       297\n",
       "I-eve       253\n",
       "B-nat       201\n",
       "I-gpe       198\n",
       "I-nat        51\n",
       "Name: count, dtype: int64"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(\"Number of tags: {}\".format(len(data.Tag.unique())))\n",
    "frequencies = data.Tag.value_counts()\n",
    "frequencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "0d89ef63",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_2269/4071460696.py:3: FutureWarning: DataFrame.fillna with 'method' is deprecated and will raise in a future version. Use obj.ffill() or obj.bfill() instead.\n",
      "  data = data.fillna(method='ffill')\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Sentence #</th>\n",
       "      <th>Word</th>\n",
       "      <th>POS</th>\n",
       "      <th>Tag</th>\n",
       "      <th>sentence</th>\n",
       "      <th>word_labels</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Sentence: 1</td>\n",
       "      <td>Thousands</td>\n",
       "      <td>NNS</td>\n",
       "      <td>O</td>\n",
       "      <td>Thousands of demonstrators have marched throug...</td>\n",
       "      <td>O,O,O,O,O,O,B-geo,O,O,O,O,O,B-geo,O,O,O,O,O,B-...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Sentence: 1</td>\n",
       "      <td>of</td>\n",
       "      <td>IN</td>\n",
       "      <td>O</td>\n",
       "      <td>Thousands of demonstrators have marched throug...</td>\n",
       "      <td>O,O,O,O,O,O,B-geo,O,O,O,O,O,B-geo,O,O,O,O,O,B-...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Sentence: 1</td>\n",
       "      <td>demonstrators</td>\n",
       "      <td>NNS</td>\n",
       "      <td>O</td>\n",
       "      <td>Thousands of demonstrators have marched throug...</td>\n",
       "      <td>O,O,O,O,O,O,B-geo,O,O,O,O,O,B-geo,O,O,O,O,O,B-...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Sentence: 1</td>\n",
       "      <td>have</td>\n",
       "      <td>VBP</td>\n",
       "      <td>O</td>\n",
       "      <td>Thousands of demonstrators have marched throug...</td>\n",
       "      <td>O,O,O,O,O,O,B-geo,O,O,O,O,O,B-geo,O,O,O,O,O,B-...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Sentence: 1</td>\n",
       "      <td>marched</td>\n",
       "      <td>VBN</td>\n",
       "      <td>O</td>\n",
       "      <td>Thousands of demonstrators have marched throug...</td>\n",
       "      <td>O,O,O,O,O,O,B-geo,O,O,O,O,O,B-geo,O,O,O,O,O,B-...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    Sentence #           Word  POS Tag  \\\n",
       "0  Sentence: 1      Thousands  NNS   O   \n",
       "1  Sentence: 1             of   IN   O   \n",
       "2  Sentence: 1  demonstrators  NNS   O   \n",
       "3  Sentence: 1           have  VBP   O   \n",
       "4  Sentence: 1        marched  VBN   O   \n",
       "\n",
       "                                            sentence  \\\n",
       "0  Thousands of demonstrators have marched throug...   \n",
       "1  Thousands of demonstrators have marched throug...   \n",
       "2  Thousands of demonstrators have marched throug...   \n",
       "3  Thousands of demonstrators have marched throug...   \n",
       "4  Thousands of demonstrators have marched throug...   \n",
       "\n",
       "                                         word_labels  \n",
       "0  O,O,O,O,O,O,B-geo,O,O,O,O,O,B-geo,O,O,O,O,O,B-...  \n",
       "1  O,O,O,O,O,O,B-geo,O,O,O,O,O,B-geo,O,O,O,O,O,B-...  \n",
       "2  O,O,O,O,O,O,B-geo,O,O,O,O,O,B-geo,O,O,O,O,O,B-...  \n",
       "3  O,O,O,O,O,O,B-geo,O,O,O,O,O,B-geo,O,O,O,O,O,B-...  \n",
       "4  O,O,O,O,O,O,B-geo,O,O,O,O,O,B-geo,O,O,O,O,O,B-...  "
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "entities_to_remove = [\"B-art\", \"I-art\", \"B-eve\", \"I-eve\", \"B-nat\", \"I-nat\"]\n",
    "data = data[~data.Tag.isin(entities_to_remove)]\n",
    "data = data.fillna(method='ffill')\n",
    "# 创建一个名为 \"sentence\" 的新列，将单词按句子进行分组\n",
    "# 创建一个名为 \"word_labels\" 的新列，将标签按句子进行分组\n",
    "data['sentence'] = data[['Sentence #','Word','Tag']].groupby(['Sentence #'])['Word'].transform(lambda x: ' '.join(x))\n",
    "data['word_labels'] = data[['Sentence #','Word','Tag']].groupby(['Sentence #'])['Tag'].transform(lambda x: ','.join(x))\n",
    "data.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ce6fda5-ccce-4917-9b15-105b033a81d7",
   "metadata": {},
   "source": [
    "创建两个字典：一个将每个标签映射到索引，另一个将索引映射到对应的标签。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "347279f1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'O': 0,\n",
       " 'B-geo': 1,\n",
       " 'B-gpe': 2,\n",
       " 'B-per': 3,\n",
       " 'I-geo': 4,\n",
       " 'B-org': 5,\n",
       " 'I-org': 6,\n",
       " 'B-tim': 7,\n",
       " 'I-per': 8,\n",
       " 'I-gpe': 9,\n",
       " 'I-tim': 10}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "label2id = {k: v for v, k in enumerate(data.Tag.unique())}\n",
    "id2label = {v: k for v, k in enumerate(data.Tag.unique())}\n",
    "label2id"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5f336b0",
   "metadata": {},
   "source": [
    "按引号查看例子和标准Tag"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "0f9d4f65",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sentence</th>\n",
       "      <th>word_labels</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Thousands of demonstrators have marched throug...</td>\n",
       "      <td>O,O,O,O,O,O,B-geo,O,O,O,O,O,B-geo,O,O,O,O,O,B-...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Families of soldiers killed in the conflict jo...</td>\n",
       "      <td>O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,B-per,O,O,...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>They marched from the Houses of Parliament to ...</td>\n",
       "      <td>O,O,O,O,O,O,O,O,O,O,O,B-geo,I-geo,O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Police put the number of marchers at 10,000 wh...</td>\n",
       "      <td>O,O,O,O,O,O,O,O,O,O,O,O,O,O,O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>The protest comes on the eve of the annual con...</td>\n",
       "      <td>O,O,O,O,O,O,O,O,O,O,O,B-geo,O,O,B-org,I-org,O,...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                            sentence  \\\n",
       "0  Thousands of demonstrators have marched throug...   \n",
       "1  Families of soldiers killed in the conflict jo...   \n",
       "2  They marched from the Houses of Parliament to ...   \n",
       "3  Police put the number of marchers at 10,000 wh...   \n",
       "4  The protest comes on the eve of the annual con...   \n",
       "\n",
       "                                         word_labels  \n",
       "0  O,O,O,O,O,O,B-geo,O,O,O,O,O,B-geo,O,O,O,O,O,B-...  \n",
       "1  O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,O,B-per,O,O,...  \n",
       "2                O,O,O,O,O,O,O,O,O,O,O,B-geo,I-geo,O  \n",
       "3                      O,O,O,O,O,O,O,O,O,O,O,O,O,O,O  \n",
       "4  O,O,O,O,O,O,O,O,O,O,O,B-geo,O,O,B-org,I-org,O,...  "
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = data[[\"sentence\", \"word_labels\"]].drop_duplicates().reset_index(drop=True)\n",
    "data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "025aeb83",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Families of soldiers killed in the conflict joined the protesters who carried banners with such slogans as \" Bush Number One Terrorist \" and \" Stop the Bombings . \"'"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.iloc[1].sentence"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e2fc85f5",
   "metadata": {},
   "source": [
    "##### 2.5 模型推理\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "178deedb-d175-448d-9f79-19c8589311bf",
   "metadata": {},
   "source": [
    "定义模型必要的变量和初始化模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "a0c1edd5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[MS_ALLOC_CONF]Runtime config:  enable_vmm:True  vmm_align_size:2MB\n"
     ]
    }
   ],
   "source": [
    "label_list = [\"O\", \"B-per\", \"I-per\", \"B-org\", \"I-org\", \"B-geo\", \"I-geo\", \"B-misc\", \"I-misc\"]\n",
    "labels_to_ids = {label: idx for idx, label in enumerate(label_list)}\n",
    "num_labels = len(labels_to_ids)  \n",
    "ids_to_labels = {v: k for k, v in labels_to_ids.items()}\n",
    "model = BertForTokenClassification.from_pretrained('huggingface-course/bert-finetuned-ner', num_labels=num_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "680f4961-31e4-4363-abcc-36649e34dca9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def ner_prediction(sentence, model, tokenizer, ids_to_labels, max_length=128):\n",
    "    \"\"\"\n",
    "    使用给定的 BERT 模型和分词器对输入句子进行命名实体识别。\n",
    "\n",
    "    参数:\n",
    "        sentence (str): 要进行命名实体识别的句子。\n",
    "        model: 已加载的 BERT 模型，用于命名实体识别。\n",
    "        tokenizer: 用于分词的 BERT 分词器。\n",
    "        ids_to_labels (dict): 标签 ID 到标签名称的映射字典。\n",
    "        max_length (int): 句子的最大长度，默认为 128。\n",
    "\n",
    "    返回:\n",
    "        list: 原始句子分词列表。\n",
    "        list: 每个词对应的标签预测结果。\n",
    "    \"\"\"\n",
    "    # 对句子进行分词\n",
    "    inputs = tokenizer(\n",
    "        sentence.split(),\n",
    "        is_split_into_words=True,\n",
    "        return_offsets_mapping=True,\n",
    "        padding='max_length',\n",
    "        truncation=True,\n",
    "        max_length=max_length,\n",
    "        return_tensors=None\n",
    "    )\n",
    "\n",
    "    # 将输入数据转换为 Tensors\n",
    "    input_ids = np.array([inputs[\"input_ids\"]], dtype=np.int32)\n",
    "    attention_mask = np.array([inputs[\"attention_mask\"]], dtype=np.int32)\n",
    "    ids = Tensor(input_ids)\n",
    "    mask = Tensor(attention_mask)\n",
    "\n",
    "    # 模型前向传播\n",
    "    outputs = model(input_ids=ids, attention_mask=mask)\n",
    "    logits = outputs[0]\n",
    "\n",
    "    # 获取预测结果\n",
    "    active_logits = logits.reshape(-1, len(ids_to_labels))\n",
    "    flattened_predictions = active_logits.argmax(axis=1)\n",
    "\n",
    "    # 将预测的标签 ID 映射回标签名称\n",
    "    tokens = tokenizer.convert_ids_to_tokens(inputs[\"input_ids\"])\n",
    "    token_predictions = [ids_to_labels[int(id)] for id in flattened_predictions.asnumpy()]\n",
    "\n",
    "    # 提取首个子词的预测结果\n",
    "    wp_preds = list(zip(tokens, token_predictions))\n",
    "    offsets = inputs[\"offset_mapping\"]\n",
    "\n",
    "    prediction = []\n",
    "    for token_pred, mapping in zip(wp_preds, offsets):\n",
    "        if mapping[0] == 0 and mapping[1] != 0:\n",
    "            prediction.append(token_pred[1])\n",
    "\n",
    "    # 输出结果\n",
    "    return sentence.split(), prediction"
   ]
  },
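  {
   "cell_type": "markdown",
   "id": "9f47a6c2",
   "metadata": {},
   "source": [
    "The last step of the function keeps one label per word by selecting the first subword of each word. The sketch below isolates that filter on hand-written data. The tokens, offsets, and labels are made up for illustration and are not real tokenizer output, and the helper name `first_subword_labels` is hypothetical. It assumes, as the function above does, that offsets are relative to the word a subword came from and that special and padding tokens have the offset `(0, 0)`.\n",
    "\n",
    "```python\n",
    "def first_subword_labels(token_labels, offsets):\n",
    "    '''Keep the label of the first subword of every word.'''\n",
    "    return [\n",
    "        label\n",
    "        for label, (start, end) in zip(token_labels, offsets)\n",
    "        # start == 0: the subword opens a word; end != 0: it is not a special or padding token\n",
    "        if start == 0 and end != 0\n",
    "    ]\n",
    "\n",
    "# Hand-written illustration for the words 'Huawei', 'is', 'big'\n",
    "tokens = ['[CLS]', 'Hu', '##awei', 'is', 'big', '[SEP]', '[PAD]']\n",
    "offsets = [(0, 0), (0, 2), (2, 6), (0, 2), (0, 3), (0, 0), (0, 0)]\n",
    "token_labels = ['O', 'B-org', 'I-org', 'O', 'O', 'O', 'O']\n",
    "print(first_subword_labels(token_labels, offsets))\n",
    "# prints ['B-org', 'O', 'O']\n",
    "```\n",
    "\n",
    "Because later subwords such as `##awei` start at a non-zero offset, their labels are dropped and the result lines up with the whitespace-split words."
   ]
  },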
  {
   "cell_type": "markdown",
   "id": "5659ec04",
   "metadata": {},
   "source": [
    "使用推理示例1-kaggle数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "ab83fb64",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "原始句子分词： ['Bedfordshire', 'police', 'said', 'Tuesday', 'that', 'Omar', 'Khayam', 'was', 'arrested', 'in', 'Bedford', 'for', 'breaching', 'the', 'conditions', 'of', 'his', 'parole', '.']\n",
      "预测标签： ['B-geo', 'O', 'O', 'O', 'O', 'B-per', 'I-per', 'O', 'O', 'O', 'B-geo', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']\n"
     ]
    }
   ],
   "source": [
    "# 输入句子\n",
    "sentence = data.iloc[41].sentence\n",
    "\n",
    "# 调用预测函数\n",
    "words, predictions = ner_prediction(sentence, model, tokenizer, ids_to_labels)\n",
    "\n",
    "# 打印输出\n",
    "print(\"原始句子分词：\", words)\n",
    "print(\"预测标签：\", predictions)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a460565c",
   "metadata": {},
   "source": [
    "使用推理示例2-tag介绍例子和其他自定义例子"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "29e3a2a1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "原始句子分词： ['Albert', 'Einstein', 'was', 'born', 'in', 'Ulm,', 'Germany', 'and', 'worked', 'at', 'the', 'University', 'of', 'Zurich.']\n",
      "预测标签： ['B-per', 'I-per', 'O', 'O', 'O', 'B-geo', 'B-geo', 'O', 'O', 'O', 'O', 'B-org', 'I-org', 'I-org']\n"
     ]
    }
   ],
   "source": [
    "# 输入句子\n",
    "sentence = \"Albert Einstein was born in Ulm, Germany and worked at the University of Zurich.\"\n",
    "\n",
    "# 调用预测函数\n",
    "words, predictions = ner_prediction(sentence, model, tokenizer, ids_to_labels)\n",
    "\n",
    "# 打印输出\n",
    "print(\"原始句子分词：\", words)\n",
    "print(\"预测标签：\", predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "4df5c73a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "原始句子分词： ['Huawei', 'is', 'a', 'company', 'based', 'in', 'Shen', 'Zhen,', 'but', 'it', 'also', 'has', 'employees', 'like', 'Eric', 'Zhang', 'working', 'in', 'Singapore']\n",
      "预测标签： ['B-org', 'O', 'O', 'O', 'O', 'O', 'B-geo', 'I-geo', 'O', 'O', 'O', 'O', 'O', 'O', 'B-per', 'I-per', 'O', 'O', 'B-geo']\n"
     ]
    }
   ],
   "source": [
    "# 输入句子\n",
    "sentence = \"Huawei is a company based in Shen Zhen, but it also has employees like Eric Zhang working in Singapore\"\n",
    "\n",
    "# 调用预测函数\n",
    "words, predictions = ner_prediction(sentence, model, tokenizer, ids_to_labels)\n",
    "\n",
    "# 打印输出\n",
    "print(\"原始句子分词：\", words)\n",
    "print(\"预测标签：\", predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "fbf855d6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "原始句子分词： ['@HuggingFace', 'is', 'a', 'company', 'based', 'in', 'New', 'York,', 'but', 'it', 'also', 'has', 'employees', 'working', 'in', 'Paris']\n",
      "预测标签： ['B-org', 'O', 'O', 'O', 'O', 'O', 'B-geo', 'I-geo', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-geo']\n"
     ]
    }
   ],
   "source": [
    "# 输入句子\n",
    "sentence = \"@HuggingFace is a company based in New York, but it also has employees working in Paris\"\n",
    "\n",
    "# 调用预测函数\n",
    "words, predictions = ner_prediction(sentence, model, tokenizer, ids_to_labels)\n",
    "\n",
    "# 打印输出\n",
    "print(\"原始句子分词：\", words)\n",
    "print(\"预测标签：\", predictions)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mindnlp",
   "language": "python",
   "name": "mindnlp"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
